# Author: Mark Richardson
# Purpose: Estimate latent policy disagreement

# Load packages

library(bayesplot)
library(dplyr)
library(forcats)
library(ggplot2)
library(rstan)
library(posterior) # Masks packages from rstan and bayesplot
library(stringr)

# Cache compiled Stan models on disk so the model is not recompiled on every run
rstan_options(auto_write = TRUE)

# Load data (defines the `sfgs` object used throughout this script)
load(file = "data/sfgs_2020.RData")

#### Descriptive look at data ####

# Frequency of disagreement: only 9 respondents answered "Never", so Never and
# Rarely need to be collapsed before modelling
table(sfgs[["pty_dis_freq"]])

# Strength of disagreement: 48 answered "No disagreement" - sufficient to keep
# that category on its own
table(sfgs[["pty_dis_strg"]])

# By agency - frequency
# Share of each frequency response within each office (only missing/NA
# responses are dropped), widened to one row per mission/dept/office with the
# office's response total appended as the last column.

freq <- sfgs %>%
  group_by(mission, dept, office, pty_dis_freq) %>%
  tidyr::drop_na(pty_dis_freq) %>%
  count() %>%
  # Totals and proportions are computed per office *name*.
  # NOTE(review): assumes office names are unique across missions/depts -
  # otherwise same-named offices would share one denominator; confirm.
  group_by(office) %>%
  mutate(total = sum(n),
         prop = n / sum(n))

# One total per office, kept aside and re-attached after pivoting
freq_total <- freq %>% distinct(office, total)

freq <- freq %>%
  select(!c(n, total)) %>%
  tidyr::pivot_wider(names_from = pty_dis_freq, values_from = c(prop)) %>%
  # Substantive responses first, in order; any other response columns follow
  select(mission, dept, office, Never, Rarely, Sometimes, Often, Always, everything())

freq <- full_join(freq, freq_total, by = "office")

rm(freq_total)

# By mission - frequency
# Counts of each frequency response per mission (one row per mission). `total`
# is the number of non-missing responses in the mission: the sum of every
# response-count column, including any non-substantive ones (e.g. DK/Refused).

freq_msn <- sfgs %>%
  group_by(mission, pty_dis_freq) %>%
  tidyr::drop_na(pty_dis_freq) %>%
  count() %>%
  group_by(mission) %>%
  tidyr::pivot_wider(names_from = pty_dis_freq, values_from = n) %>%
  select(mission, Never, Rarely, Sometimes, Often, Always, everything()) %>%
  rowwise() %>%
  # Sum all count columns instead of the positional range Never:Refused. That
  # range depends on the order in which pivot_wider() emits the non-substantive
  # columns, so it can silently skip one. It also errors if no respondent
  # answered "Refused". `mission` is the rowwise key, so c_across() excludes it.
  mutate(total = sum(c_across(where(is.numeric)), na.rm = TRUE)) %>%
  # Drop the rowwise grouping so later verbs on freq_msn operate column-wise
  ungroup()

# By agency - strength
# Share of each strength response within each office (only missing/NA
# responses are dropped), widened to one row per mission/dept/office with the
# office's response total appended as the last column.

strg <- sfgs %>%
  group_by(mission, dept, office, pty_dis_strg) %>%
  tidyr::drop_na(pty_dis_strg) %>%
  count() %>%
  # Totals and proportions are computed per office *name*.
  # NOTE(review): assumes office names are unique across missions/depts -
  # otherwise same-named offices would share one denominator; confirm.
  group_by(office) %>%
  mutate(total = sum(n),
         prop = n / sum(n))

# One total per office, kept aside and re-attached after pivoting
strg_total <- strg %>% distinct(office, total)

strg <- strg %>%
  select(!c(n, total)) %>%
  tidyr::pivot_wider(names_from = pty_dis_strg, values_from = c(prop)) %>%
  # Substantive categories first, weakest to strongest; other columns follow
  select(mission, dept, office,
         `No disagreement`,
         `Low intensity disagreement`,
         `Moderate intensity disagreement`,
         `High intensity disagreement`,
         everything())

strg <- full_join(strg, strg_total, by = "office")

rm(strg_total)

# By mission - strength
# Counts of each strength response per mission (one row per mission). `total`
# is the number of non-missing responses in the mission: the sum of every
# response-count column, including any non-substantive ones (e.g. DK/Refused).
# This is the same denominator as the per-office strength table
# (`total = sum(n)` over all non-missing responses).

strg_msn <- sfgs %>%
  group_by(mission, pty_dis_strg) %>%
  tidyr::drop_na(pty_dis_strg) %>%
  count() %>%
  group_by(mission) %>%
  tidyr::pivot_wider(names_from = pty_dis_strg, values_from = n) %>%
  select(mission,
         `No disagreement`,
         `Low intensity disagreement`,
         `Moderate intensity disagreement`,
         `High intensity disagreement`,
         everything()) %>%
  rowwise() %>%
  # `mission` is the rowwise key, so c_across() sums only the count columns
  mutate(total = sum(c_across(where(is.numeric)), na.rm = TRUE)) %>%
  # Drop the rowwise grouping so later verbs on strg_msn operate column-wise
  ungroup()

##### Prepare data for the model ####

##### Format variables ####

# Substantive response labels: the first five frequency levels and the first
# four strength levels (the remaining levels are DK/Refused)
freq_lvls <- levels(sfgs$pty_dis_freq)[seq_len(5)]
strg_lvls <- levels(sfgs$pty_dis_strg)[seq_len(4)]

# Ordered factors restricted to the substantive labels (anything else becomes
# NA). For frequency, the sparse "Never" and "Rarely" responses are merged into
# a single "Rarely-Never" category.

sfgs <- sfgs %>%
  mutate(
    pty_dis_freq_ord = fct_collapse(
      factor(pty_dis_freq, levels = freq_lvls, ordered = TRUE),
      `Rarely-Never` = c("Rarely", "Never")
    ),
    pty_dis_strg_ord = factor(pty_dis_strg, levels = strg_lvls, ordered = TRUE)
  )

# Accuracy check: original responses against their recoded categories
table(sfgs$pty_dis_freq, sfgs$pty_dis_freq_ord, useNA = "always")
table(sfgs$pty_dis_strg, sfgs$pty_dis_strg_ord, useNA = "always")

# Integer codes (position in the factor's level order), as required by Stan

sfgs <- sfgs %>%
  mutate(
    pty_dis_freq_int = as.integer(pty_dis_freq_ord),
    pty_dis_strg_int = as.integer(pty_dis_strg_ord)
  )

# Accuracy check: each category maps to exactly one integer code
table(sfgs$pty_dis_freq_ord, sfgs$pty_dis_freq_int, useNA = "always")
table(sfgs$pty_dis_strg_ord, sfgs$pty_dis_strg_int, useNA = "always")

#### Prune agencies ####

# Drop "Other (...)" offices from departments

sfgs_dis <- sfgs %>%
  filter(!str_detect(office, "Other \\("))

# Accuracy check - only "Other" offices dropped. The dropped offices are shown,
# and the check fails loudly instead of being computed and discarded.
check <- anti_join(sfgs %>% select(office) %>% distinct(),
                   sfgs_dis %>% select(office) %>% distinct(),
                   by = "office")

print(check)

# NA office names are also removed by filter() (str_detect() returns NA).
# They are not "Other" offices, so they are ignored here; rows without an
# office are excluded from the model data later regardless.
stopifnot(all(str_detect(check$office, "Other \\("), na.rm = TRUE))

rm(check)

# Check for responses to freq but not strg or vice versa (full, unpruned data)

table(freq = is.na(sfgs$pty_dis_freq_int), strg = is.na(sfgs$pty_dis_strg_int)) # 49 strg responses with no freq response; 10 freq responses with no strg response

#### Format data for stan - drop respondents that did not answer both questions ####

# Keep respondents who gave substantive (non-DK/Refused) answers to both
# questions and have a known office, department and mission (needed to build
# the indexes). `bureau` is kept in the saved data, but it is not a model input.
# A missing bureau therefore does not remove a respondent from the estimation
# sample.

sfgs_dis_complete <- sfgs_dis %>%
  select(pty_dis_freq_int, pty_dis_strg_int, office, bureau, dept, mission) %>%
  tidyr::drop_na(pty_dis_freq_int, pty_dis_strg_int, office, dept, mission) %>%
  arrange(mission, dept, office)

# Row count before merging indexes, used to verify the joins add no rows
n <- nrow(sfgs_dis_complete)

#### Get office and mission indexes ####

# One integer index per unique mission/dept/office combination. Numbering
# follows the sorted mission, dept, office order, so offices in the same
# mission receive consecutive indexes.
office_index <- sfgs_dis_complete %>%
  distinct(mission, dept, office) %>%
  arrange(mission, dept, office) %>%
  mutate(office_index = seq_len(n()))

# One integer index per mission, in sorted mission order
mission_index <- sfgs_dis_complete %>%
  distinct(mission) %>%
  arrange(mission) %>%
  mutate(mission_index = seq_len(n()))

# Merge indexes with data. Both lookup tables are built from this same data,
# so every row has exactly one match. A left join expresses that intent.

sfgs_dis_complete <- left_join(sfgs_dis_complete, office_index, by = c("mission", "dept", "office"))

sfgs_dis_complete <- left_join(sfgs_dis_complete, mission_index, by = "mission")

# Enforce (rather than just print) that the joins added no rows
stopifnot(nrow(sfgs_dis_complete) == n)

rm(n)

# Mission index of each office, one row per office in office_index order.
# Passed to Stan as `mu_group` for the non-centered parameterization.

msn_office_grp <- sfgs_dis_complete %>%
  select(mission_index, office_index) %>%
  distinct() %>%
  arrange(office_index, mission_index)

# Each office must belong to exactly one mission, so mu_group has one entry per office
stopifnot(nrow(msn_office_grp) == length(unique(sfgs_dis_complete$office_index)))

# List of data for stan
# The number of response categories (K) comes from the ordered factors' levels,
# not from the categories observed in this sample. Counting unique responses
# would understate K if a category had no respondents, and could leave observed
# codes larger than K.

pty_dis_data <- list(N = nrow(sfgs_dis_complete),
                     J = length(unique(sfgs_dis_complete$office_index)),
                     M = length(unique(sfgs_dis_complete$mission_index)),
                     y_freq = sfgs_dis_complete$pty_dis_freq_int,
                     y_strg = sfgs_dis_complete$pty_dis_strg_int,
                     jj = sfgs_dis_complete$office_index,
                     mm = sfgs_dis_complete$mission_index,
                     mu_group = msn_office_grp$mission_index,
                     K_freq = nlevels(sfgs$pty_dis_freq_ord),
                     K_strg = nlevels(sfgs$pty_dis_strg_ord))

#### Estimate the model ####

# Stan's default random initial values produced parameter combinations for
# which the log density (the model block) evaluated to -Inf (probability 0).
# sampleInit() builds explicit initial values instead. theta_raw gets
# over-dispersed starting values (one constant per chain); all other
# parameters start at the same valid values on every chain.

# Build per-chain initial values for rstan::stan(init = ...).
#
# n_chains:     number of chains; the length of the returned list.
# data:         list supplying J (number of offices) and M (number of
#               missions); defaults to the global `pty_dis_data`.
# theta_starts: constant starting value of theta_raw for each chain, recycled
#               when n_chains exceeds its length. Every chain therefore gets a
#               finite start.
# Returns a list of length n_chains; element i is a named list of initial
# values for chain i. c_freq/c_strg start at -1, 0, 1.
# NOTE(review): presumably this matches K - 1 = 3 cutpoints in the Stan model;
# confirm against code/pty_dis_non_centered.stan.
sampleInit <- function(n_chains, data = pty_dis_data,
                       theta_starts = c(-4, -2, 2, 4)) {

  # theta_raw starting value for each chain (recycled if n_chains > 4)
  chain_theta <- rep_len(theta_starts, n_chains)

  lapply(seq_len(n_chains), function(i) {
    list(c_freq = -1:1,
         c_strg = -1:1,
         theta_raw = rep(chain_theta[i], times = data$J),
         mu = rep(0, times = data$M),
         tau = rep(1, times = data$M),
         delta_freq = 0.5,
         delta_strg = 0.5)
  })

}

# Four chains, each on its own core
nc <- 4

# Over-dispersed theta_raw starts; shared valid starts for everything else
dis_inits <- sampleInit(n_chains = nc)

pty_dis_fit <- stan(
  file = "code/pty_dis_non_centered.stan",
  data = pty_dis_data,
  # Per chain: 4,000 iterations in total, the first 1,000 of which are warmup
  chains = nc,
  cores = nc,
  iter = 4000,
  warmup = 1000,
  init = dis_inits,
  seed = 1703,
  # Progress report every 500 iterations; sampler diagnostics written under
  # data/pty_dis_diag
  refresh = 500,
  diagnostic_file = "data/pty_dis_diag",
  # Small initial step size and a higher tree-depth cap than rstan's default
  control = list(
    adapt_delta = 0.80,
    max_treedepth = 15,
    stepsize = 0.001
  )
)

#### Check diagnostics ####

# Sampling time per chain (warmup and sampling), converted from seconds to hours
get_elapsed_time(pty_dis_fit) / 60^2

# HMC sampler checks: tree-depth saturation, energy (E-BFMI), divergences
check_treedepth(pty_dis_fit)
check_energy(pty_dis_fit)
check_divergences(pty_dis_fit)

# Diagnostic plots: sampler behaviour, R-hat, effective sample size and Monte
# Carlo standard error across parameters
stan_diag(pty_dis_fit)
stan_rhat(pty_dis_fit)
stan_ess(pty_dis_fit)
stan_mcse(pty_dis_fit)


# Create a function to calculate diagnostics for model parameters

# Compute R-hat, bulk ESS and tail ESS (rstan's rank-normalized versions) for
# every parameter row reported by summary() of a stanfit.
#
# stan_object: a fitted stanfit object.
# pars:        optional character vector of parameter names used to subset the
#              summary (passed as summary(pars = )). NULL (the default) or a
#              zero-length vector means all parameters.
# Returns a list of three numeric vectors -- r_hat, ess_bulk, ess_tail -- each
# named by parameter (the row names of the summary).
stanDiag <- function(stan_object, pars = NULL) {

  # The summary's row names determine which parameters are diagnosed
  p_smy <- if (length(pars) == 0) {
    summary(stan_object)$summary
  } else {
    summary(stan_object, pars = pars)$summary
  }
  p <- rownames(p_smy)

  # Draws as an iterations x chains x parameters array
  array_of_draws <- as.array(stan_object)

  # Apply one diagnostic to each parameter's iterations x chains draws.
  # vapply() keeps the parameter names and is empty-safe, unlike 1:length(p).
  # rstan:: is explicit because the posterior package masks these names.
  diag_by_par <- function(diag_fun) {
    vapply(p, function(par) diag_fun(array_of_draws[, , par]), numeric(1))
  }

  list(r_hat = diag_by_par(rstan::Rhat),
       ess_bulk = diag_by_par(rstan::ess_bulk),
       ess_tail = diag_by_par(rstan::ess_tail))

}

# R-hat, bulk ESS and tail ESS for every parameter reported by summary()
diag_pty_dis <- stanDiag(pty_dis_fit)

# R-hat - all below 1.05
mcmc_rhat_hist(diag_pty_dis$r_hat)

# ESS to TSS (effective / total sample size ratio) - some below 0.5
mcmc_neff_hist(neff_ratio(pty_dis_fit))

# Bulk and tail ESS
# Histograms of per-parameter ESS; the dashed line marks ESS = 400

ggplot(data = tibble(ess_bulk = diag_pty_dis$ess_bulk)) +
  geom_histogram(aes(ess_bulk), binwidth = 1000, color = "black", fill = "dodgerblue", alpha = 0.8) +
  geom_vline(xintercept = 400, linetype = "dashed") +
  theme_bw()

min(diag_pty_dis$ess_bulk) # min is > 400

ggplot(data = tibble(ess_tail = diag_pty_dis$ess_tail)) +
  geom_histogram(aes(ess_tail), binwidth = 1000, color = "black", fill = "dodgerblue", alpha = 0.8) +
  geom_vline(xintercept = 400, linetype = "dashed") +
  theme_bw()

min(diag_pty_dis$ess_tail) # min is > 400

# MCSE to Posterior SD
# (same plot as in the "Check diagnostics" section)

stan_mcse(pty_dis_fit)

#### Look at parameters ####

# Posterior summary table for all parameters, pooled over chains
pty_dis_smy <- summary(pty_dis_fit)[["summary"]]

#### Save model ####

# Persist the fit together with the data and the index lookup tables needed to
# map office/mission indexes back to names
save(list = c("pty_dis_fit", "sfgs", "sfgs_dis_complete", "mission_index", "office_index"),
     file = "data/01_disagreement_model_fitted.RData")

